import tensorflow as tf
import os
import numpy as np
import json
import abc
import PIL.Image as Image


class HandDetectionDataset:
    """Hand-detection dataset backed by per-image JSON label files.

    Expected layout under ``dataset_path``::

        image/<stem>.jpg
        label/<stem>.json

    Each label file is read as
    ``{"outputs": {"object": [{"name": ..., "bndbox": {"xmin": ...,
    "ymin": ..., "xmax": ..., "ymax": ...}}, ...]}}`` and each object name is
    mapped to an integer id through ``label_idx``.
    """

    def __init__(self, dataset_path, shuffle=True, num_classes=3):
        """Index the label files found under ``dataset_path``.

        Args:
            dataset_path: Root directory containing ``image/`` and ``label/``.
            shuffle: If true, shuffle the file list in place with
                ``np.random.shuffle``.
            num_classes: Stored on the instance; not used by this class itself.

        Raises:
            FileNotFoundError: If the root, ``image`` or ``label`` directory
                is missing.
        """
        image_dir = os.path.join(dataset_path, 'image')
        label_dir = os.path.join(dataset_path, 'label')
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        for required_dir in (dataset_path, image_dir, label_dir):
            if not os.path.isdir(required_dir):
                raise FileNotFoundError(
                    'Dataset directory not found: {}'.format(required_dir))

        self.dataset_path = dataset_path
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # unshuffled order (and a seeded shuffle) reproducible.
        self.filename_list = sorted(
            name for name in os.listdir(label_dir) if name.endswith('.json'))
        self.label_idx = {'figure': 1, 'hand_small': 2, 'hand_large': 3}
        if shuffle:
            np.random.shuffle(self.filename_list)
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of label files indexed."""
        return len(self.filename_list)

    def _load_sample(self, filename, normalized=True):
        """Load the image and raw annotations for one label file.

        Args:
            filename: Name of a ``.json`` file inside the ``label`` directory.
            normalized: If true, divide x coordinates by the image width and
                y coordinates by the image height; otherwise keep the values
                exactly as stored in the label file.

        Returns:
            ``(image, boxes, categories)`` where ``image`` is the RGB image as
            a numpy array, ``boxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` lists and ``categories`` is a list of
            single-element ``[class_id]`` lists.

        Raises:
            FileNotFoundError: If the matching ``.jpg`` image does not exist.
            KeyError: If the label file lacks an expected key or names a class
                that is not in ``label_idx``.
        """
        label_path = os.path.join(self.dataset_path, 'label', filename)
        # Keep the label file open only while it is being parsed.
        with open(label_path, 'r') as f:
            label = json.load(f)

        stem = os.path.splitext(filename)[0]
        image_path = os.path.join(self.dataset_path, 'image', stem + '.jpg')
        if not os.path.isfile(image_path):
            raise FileNotFoundError(
                'Image for label {} not found: {}'.format(filename, image_path))
        # The context manager releases the underlying file handle promptly.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))
        im_height, im_width = image.shape[:2]

        boxes = []
        categories = []
        for obj in label['outputs']['object']:
            bndbox = obj['bndbox']
            xmin, ymin = bndbox['xmin'], bndbox['ymin']
            xmax, ymax = bndbox['xmax'], bndbox['ymax']
            if normalized:
                boxes.append([xmin / im_width, ymin / im_height,
                              xmax / im_width, ymax / im_height])
            else:
                boxes.append([xmin, ymin, xmax, ymax])
            categories.append([self.label_idx[obj['name']]])
        return image, boxes, categories

    def __getitem__(self, index):
        """Return ``(image, boxes, categories)`` for the sample at ``index``.

        ``boxes`` has shape ``(n, 4)`` with coordinates normalized by the image
        size and ``categories`` has shape ``(n, 1)``; both keep that 2-D shape
        when the label file lists no objects (``n == 0``).
        """
        image, boxes, categories = self._load_sample(
            self.filename_list[index], normalized=True)
        return (image,
                np.array(boxes, dtype=np.float64).reshape(-1, 4),
                np.array(categories, dtype=int).reshape(-1, 1))

    def create_tf_dataset(self, normalized=True, one_hot=False):
        """Build a ``tf.data.Dataset`` that streams samples from disk.

        Only the FIRST annotated object of each image is emitted, so the box
        element always has shape ``(1, 4)`` and the category element shape
        ``(1,)``. Samples whose label file lists no objects are skipped,
        because there is no first object to emit for them.

        Args:
            normalized: If true, box coordinates are divided by the image
                width/height; otherwise they are emitted as stored.
            one_hot: Accepted for interface compatibility; currently has no
                effect (one-hot encoding is not implemented).

        Returns:
            A dataset of ``(image, box, category)`` elements declared with
            dtypes ``(tf.float32, tf.float32, tf.int32)``.
        """
        def gen():
            for filename in self.filename_list:
                image, boxes, categories = self._load_sample(
                    filename, normalized=normalized)
                if not boxes:
                    # Nothing annotated: skip instead of failing on boxes[0].
                    continue
                yield (image,
                       np.array(boxes[0]).reshape([1, 4]),
                       np.array(categories[0]))

        dataset = tf.data.Dataset.from_generator(
            generator=gen, output_types=(tf.float32, tf.float32, tf.int32))
        return dataset

    
if __name__ == '__main__':
    # Manual smoke test: index the training split and report how many samples it has.
    dataset = HandDetectionDataset(dataset_path='/media/wei/Memory/ssd-dataset/train')
    print(len(dataset))

    # Walk the tf.data pipeline once, printing the shape of every element.
    tf_dataset = dataset.create_tf_dataset(normalized=True)
    for sample_no, sample in enumerate(tf_dataset, start=1):
        image, box, category = sample
        print('Sample:#{} image:{} Box:{} Category:{}'.format(
            sample_no, image.shape, box.shape, category.shape))

    # Walk it again in batches of four and dump each batch.
    for batch in tf_dataset.batch(batch_size=4):
        print(batch)
